"""
QUTIP simulations of the coherent motional excitation with the running lattice

Parts based on Gregor Hegi code provided in Appendix C of Hegi, G.: 'Towards A Non-Destructive Single Molecular Ion State Readout' PhD thesis, University of Basel (2021)

@author: Mikolaj Roguski
"""

from qutip import *
import numpy as np
import math
import matplotlib.pyplot as plt
import sys
import os 
import re

from PlottingAndSavingPackage_v3 import *

# Constants pre-define
pi = np.pi
hbar = 1.05457173e-34
amu = 1.660539040e-27

def create_trap_Hamiltonian(m1,m2,v1,aIP_op,aOP_op):
    """
    Build the free two-ion trap Hamiltonian and the position operators of both ions.

    Parameters
    ----------
    m1, m2 : float
        Ion masses in atomic mass units. The mode amplitudes are scaled with m1.
    v1 : float
        Single-ion axial trap frequency (MHz in the script settings); presumably
        the frequency of ion 1 (mass m1) -- TODO confirm convention.
    aIP_op, aOP_op : Qobj
        Annihilation operators of the in-phase (IP) and out-of-phase (OP) axial
        modes, defined on the joint IP x OP Hilbert space.

    Returns
    -------
    H0 : Qobj
        wIP*(nIP + 1/2) + wOP*(nOP + 1/2) with w = 2*pi*v (rad/us when v is in MHz).
    x1, x2 : Qobj
        Position operators of ion 1 and ion 2 in micrometres.
    vIP, vOP : float
        IP and OP mode frequencies, in the same units as v1.
    """
    # Mass ratio (dimensionless, taken before converting masses to kg)
    mu = m1 / m2
    m1_kg = m1 * amu

    # Normal-mode frequencies of the two-ion crystal
    root = np.sqrt(1. - mu + mu**2)
    vIP = v1 * np.sqrt(1. + mu - root)
    vOP = v1 * np.sqrt(1. + mu + root)

    # Angular mode frequencies (rad/us for frequencies given in MHz)
    wIP = 2 * pi * vIP
    wOP = 2 * pi * vOP

    # Ground-state spread of each mode; the 1e6 inside converts rad/us -> rad/s,
    # the 1e6 outside converts m -> um.
    xIP_O = np.sqrt(hbar / (2*m1_kg*wIP*1.0e6)) * 1.0e6   # [um]
    xOP_O = np.sqrt(hbar / (2*m1_kg*wOP*1.0e6)) * 1.0e6   # [um]

    # Mode position operators.
    # Note: if the lattice acts on both ions, a relative displacement between
    # the position operators has to be added.
    xIP_op = xIP_O * (aIP_op.dag() + aIP_op)
    xOP_op = xOP_O * (aOP_op.dag() + aOP_op)

    # Projection of the normal modes onto the individual ion positions
    u = 1 / np.sqrt(1 + (1 - mu - root) ** 2 / mu)
    v = np.sqrt(1 - u**2)
    x1 = v * xIP_op - u * xOP_op
    x2 = np.sqrt(mu) * (u * xIP_op + v * xOP_op)

    # Harmonic Hamiltonian of both modes, including the zero-point term 1/2
    H0 = wIP * (aIP_op.dag()*aIP_op + 0.5) + wOP * (aOP_op.dag()*aOP_op + 0.5)

    return H0, x1, x2, vIP, vOP


def create_lattice_interaction_hamiltonian(L_lattice, x_act, Eac0, m_max_cos_sin):
    """
    Build truncated Taylor expansions of the running-lattice interaction terms.

    With k = 2*pi / L_lattice and x the position (operator) of the ion that
    interacts with the lattice, the function returns

        H1_cos   = 2*Eac0 * cos(2 k x)   (Taylor series)
        H1_sin   = 2*Eac0 * sin(2 k x)   (Taylor series)
        H1_const = 2*Eac0

    Both series keep only the terms whose power of (2 k x) is <= m_max_cos_sin.
    They are meant to be multiplied by time-dependent cos/sin coefficients.

    Parameters
    ----------
    L_lattice : float
        Lattice wavelength, in the same length unit as x_act (um in the script).
    x_act : Qobj or float
        Position of the ion acted on by the lattice.
    Eac0 : float
        ac-Stark shift amplitude of a single lattice beam (angular units).
    m_max_cos_sin : int
        Maximum power kept in the sine/cosine expansions (must be >= 1).

    Returns
    -------
    (H1_cos, H1_sin, H1_const)

    Raises
    ------
    ValueError
        If m_max_cos_sin < 1 (the sine expansion would be empty).
    """
    if m_max_cos_sin < 1:
        raise ValueError(f"m_max_cos_sin must be >= 1, got {m_max_cos_sin}")

    k = 2 * math.pi / L_lattice
    arg = 2 * k * x_act
    prefactor = 2 * Eac0

    # cos: even powers 0, 2, ..., <= m_max_cos_sin
    H1_cos = prefactor * sum(
        (-1) ** m * arg ** (2 * m) / math.factorial(2 * m)
        for m in range(m_max_cos_sin // 2 + 1)
    )
    # sin: odd powers 1, 3, ..., <= m_max_cos_sin (for even orders the original
    # series overshot by one power)
    H1_sin = prefactor * sum(
        (-1) ** m * arg ** (2 * m + 1) / math.factorial(2 * m + 1)
        for m in range((m_max_cos_sin - 1) // 2 + 1)
    )
    H1_const = prefactor

    return H1_cos, H1_sin, H1_const


def run_mesolve(H0, H1_const, H1_cos, H1_sin, mot0, tlist, v_lattice, dm0, aIP_op, aOP_op, check_norm_flag=True):
    """
    Propagate the motional state under the running-lattice Hamiltonian with QuTiP's mesolve.

    The Hamiltonian handed to mesolve is

        H(t) = (H0 + H1_const) + H1_cos * cos(w t) + H1_sin * sin(w t),

    with w = 2*pi*v_lattice + 2*pi*dm0. No collapse operators are used. The
    expectation values of the IP and OP phonon-number operators are recorded.

    Parameters
    ----------
    H0, H1_const, H1_cos, H1_sin : Qobj
        Static trap Hamiltonian, constant lattice offset, and the operators
        multiplying cos(w t) and sin(w t).
    mot0 : Qobj
        Initial motional state.
    tlist : array_like
        Times at which the state is stored (us in the script settings).
    v_lattice, dm0 : float
        Lattice running frequency and its detuning (MHz in the script settings).
    aIP_op, aOP_op : Qobj
        Mode annihilation operators used to build the phonon-number operators.
    check_norm_flag : bool, optional
        If True, print the sum of squared magnitudes of all entries of the final
        stored state (the squared norm when that state is a ket).

    Returns
    -------
    The result object returned by mesolve.
    """
    print('(Running calculations:) Start solving')
    solver_options = dict(
        atol=1e-9,           # absolute tolerance
        rtol=1e-8,           # relative tolerance
        nsteps=1e6,          # maximum number of integrator steps
        store_states=True,   # keep the state at every time in tlist
        progress_bar=True,   # display a progress bar
    )

    # Angular running frequency of the lattice: nominal frequency plus detuning
    w = 2*pi*v_lattice + 2*pi*dm0

    # Time-dependent coefficients with the (t, args) signature
    def coeff_cos(t, args):
        return np.cos(w * t)

    def coeff_sin(t, args):
        return np.sin(w * t)

    hamiltonian = [H0 + H1_const, [H1_cos, coeff_cos], [H1_sin, coeff_sin]]
    phonon_number_ops = [aIP_op.dag() * aIP_op, aOP_op.dag() * aOP_op]
    data = mesolve(hamiltonian, mot0, tlist, c_ops=[], e_ops=phonon_number_ops,
                   options=solver_options)

    if check_norm_flag:
        squared_magnitudes = (np.abs(data.states[-1].full()) ** 2).flatten()
        print('(Running calculations:) The normalisation check:', sum(squared_magnitudes))

    return data


def thermal_state_from_fock(N, nbar, *, normalize=False):
    """
    Build a pure ket whose Fock populations follow the Bose-Einstein (thermal) distribution.

    The returned ket is sum_n sqrt(P_n) |n> for n = 0 .. N-1, with
    P_n = nbar**n / (1 + nbar)**(n + 1). It is a coherent superposition with
    thermal populations, NOT the mixed thermal density matrix. This lets it be
    combined with other kets via ``tensor``.

    Parameters
    ----------
    N : int
        Dimension of the truncated Fock space (must be >= 1).
    nbar : float
        Mean phonon number of the distribution (must be >= 0).
    normalize : bool, optional, keyword-only
        If True, rescale the ket to unit norm. By default it is returned as-is.
        Because of the truncation at N-1, its squared norm is then
        1 - (nbar / (1 + nbar))**N, which is below 1 whenever nbar > 0.

    Returns
    -------
    Qobj
        Ket of dimension N.

    Raises
    ------
    ValueError
        If N < 1 or nbar < 0 (negative nbar would give NaN amplitudes).
    """
    if N < 1:
        raise ValueError(f"N must be >= 1, got {N}")
    if nbar < 0:
        raise ValueError(f"nbar must be non-negative, got {nbar}")

    # Start from the zero ket and accumulate the weighted Fock states
    state = Qobj(np.zeros((N, 1)))
    for n in range(N):
        P_n = (nbar ** n) / ((1 + nbar) ** (n + 1))  # Bose-Einstein probability
        state += np.sqrt(P_n) * fock(N, n)

    # Normalise once after the sum (the earlier commented-out attempt sat inside the loop)
    if normalize:
        state = state.unit()

    return state


def phonon_distribution_set(N, nbar, distribution='thermal'):
    """
    Construct a pure ket of dimension N with the requested phonon distribution.

    Supported distributions:
      - 'thermal'   : sum_n sqrt(P_n)|n> with Bose-Einstein P_n (see
                      thermal_state_from_fock); not renormalised after truncation.
      - 'coherent'  : QuTiP coherent state with amplitude sqrt(nbar).
      - 'coherent2' : sum_n sqrt(P_n)|n> with Poisson P_n = nbar**n e**-nbar / n!
                      for n < N-2. The last two Fock amplitudes are forced to zero.
                      The ket is not renormalised.

    The squared norm of the result is printed.

    Raises
    ------
    ValueError
        If the distribution name is unknown. Also for 'coherent2' with N <= 2,
        where every amplitude would be forced to zero.
    """
    if distribution not in ('thermal', 'coherent', 'coherent2'):
        raise ValueError(f"Unknown phonon distribution: {distribution!r}")
    if distribution == 'coherent2' and N <= 2:
        raise ValueError("'coherent2' needs N > 2: the last two Fock states are forced to zero")

    if distribution == 'thermal':
        rho_state = thermal_state_from_fock(N, nbar)

    elif distribution == 'coherent':
        rho_state = coherent(N, np.sqrt(nbar))

    else:  # 'coherent2'
        rho_state = Qobj(np.zeros((N, 1)))  # zero ket to accumulate into
        # Only n < N-2: the last two Fock states keep zero amplitude
        for n in range(N - 2):
            P_n = (nbar ** n) * np.exp(-nbar) / math.factorial(n)  # Poisson probability
            rho_state += np.sqrt(P_n) * fock(N, n)

    normalization = (rho_state.dag() * rho_state)
    print(f"Normalization: {normalization}")

    return rho_state



# --- SAVING DATA OPTIONS --- 
data_folder_name = "generated_data"


# --- USER DEFINED PARAMETERS --- 
m1 = 40. # mass of ion 1 [amu]
m2 = 28. # mass of ion 2 [amu]
v1 = 0.62605 #[MHz] single-ion axial frequency

# Set up Hilbert spaces (first IP, then OP)
NIP = 80 # Number of IP Fock states (needs to be large because the population of the target mode increases!)
NOP = 2 # Number of OP Fock states (min. 2)

L_lattice = 0.7874505 #[um] lattice wavelength

dm0 = 0.000 # [MHz] detuning of the lattice running frequency


m_max_cos_sin = 21 # maximum Taylor-expansion order for the sine/cosine terms of the lattice Hamiltonian
Eac0 = 2*pi*0.016 # ac-Stark shift from a single lattice laser on N2 only: 2*pi x 0.016 MHz, i.e. angular units [rad/us]

# Initial populations of the modes
nIP_init = 0  
nOP_init = 0 # if the ax-OP population is non-zero, update this value and increase the Fock space NOP!

choose_OP_phonon_dist = 'coherent' # spectator-mode phonon distribution: 'fock', 'thermal', 'coherent' or 'coherent2'


tmax = 500 #[us] maximum time span of the simulation
tstep = 5e-2 #[us] time step of the simulation, important to optimise
# np.linspace(0, tmax, n) has spacing tmax/(n-1), so one extra point (the endpoint)
# is needed for consecutive times to be tstep apart. round() absorbs float noise
# in tmax/tstep before ceil.
tpoints = math.ceil(round(tmax/tstep, 9)) + 1 # number of time points
print(f'(Settings:) Timestep: {tstep} us')
tlist = np.linspace(0,tmax,tpoints)


# --- PERFORM CALCULATIONS --- 
# Additional options:
#   - check whether a similar file exists; if so, update only part of the output (saves time!),
#     recompute everything, or don't overwrite. The choice can be requested from the user or made automatically.
#   - parameters can be scanned in the main loop, e.g. `for NIP in [60, 70, 80]:`


for _ in np.arange(1):
      
    # Initial motional state: Fock state |nIP_init> on the IP mode, chosen distribution on the OP (spectator) mode
    if choose_OP_phonon_dist == 'fock':
        mot0 = tensor(fock(NIP,nIP_init), fock(NOP,nOP_init))
    elif choose_OP_phonon_dist == 'thermal':     
        mot0 = tensor(fock(NIP,nIP_init), thermal_state_from_fock(NOP,nOP_init))
    elif choose_OP_phonon_dist == 'coherent':
        mot0 = tensor(fock(NIP,nIP_init), coherent(NOP, np.sqrt(nOP_init)))
    elif choose_OP_phonon_dist == 'coherent2':
        mot0 = tensor(fock(NIP,nIP_init), phonon_distribution_set(NOP, nOP_init, distribution=choose_OP_phonon_dist))
    else:
        print('issue with the OP phonon dist!')
        break
    
    # There are two cases in which a user response may be requested:
    # 1. If a file exists, should it be overwritten? (auto_skipFlag in check_file_existance())
    # 2. If a similar file exists, should its already computed data be reused? (enable_user_responses below)
    enable_user_responses = False
    default_response = 'no' # 'yes' == reuse the time range already covered by a similar file
                            # 'no'  == recompute the full range even if a similar file exists
    
    # Annihilation operators of the IP and OP modes on the joint (IP x OP) space
    aIP_op = tensor(destroy(NIP),qeye(NOP))
    aOP_op = tensor(qeye(NIP),destroy(NOP))

    # Trap Hamiltonian and lattice interaction (the lattice acts on ion 2 through x2)
    H0, x1, x2, vIP, vOP = create_trap_Hamiltonian(m1,m2,v1, aIP_op, aOP_op)        
    H1_cos, H1_sin, H1_const = create_lattice_interaction_hamiltonian(L_lattice, x2, Eac0, 
                                                            m_max_cos_sin)  
    
    # Running frequency of the lattice: resonant with the IP mode (dm0 is added in run_mesolve)
    v_lattice = vIP

    # Parameters dictionary used for file naming of data and plots.
    # NOTE(review): v1, v_lattice, dm0 and Eac0/2/pi are in MHz and L_lattice in um, but
    # they are labelled 'kHz'/'nm' here. Confirm whether PlottingAndSavingPackage_v3 converts
    # units before relying on these labels (changing them renames the output files).
    exp_params_dict = {
        "m1": m1,
        "m2": m2, 
        "v1": (v1, 'kHz'),
        "NIP": NIP,
        "NOP": NOP,
        "Llattice": (L_lattice, 'nm'),
        "vlattice": (round(v_lattice,3), 'kHz'),
        "dm0": (dm0, 'kHz'),
        "Eac0": (Eac0/2/pi, 'kHz'),
        "nIPinit": nIP_init, 
        "nOPinit": nOP_init, 
        "tmax": (tmax, 'us'),
        "tstep": (tstep*1e3, 'ns'),
        "OPphDist": choose_OP_phonon_dist
    }
    
    # Handles skipped in the filename.
    # NOTE(review): the dictionary key is 'Llattice', so 'L_lattice' may not match anything --
    # confirm against the saving package before changing (it would alter file names).
    skip_handles = ['m1', 'm2','L_lattice']
    
    file_exists_bool, file_path = check_file_existance('q2', exp_params_dict, data_folder_name, 
                                                       skip_handles=skip_handles, auto_skipFlag=True, auto_def_response='yes')   
                                                                               # auto_def_response: 'yes' -- overwrite the file if exists
                                                                               #                    'no' -- do NOT overwrite the file if exists
      
    if not file_exists_bool:
        
        # Combine with an existing file, to avoid repeating tedious calculations
        # with long time spans.
        # Potential issues:
        # - if the Fock spaces differ there would be an issue with calculating exp. values
        
        def check_if_file_differs_by_one_param(file_path1, file_path2, data_folder_name, exclude_key, skip_handles=None):
            # True if the two filenames encode the same parameters apart from exclude_key (and skip_handles)
            skip_handles = [] if skip_handles is None else skip_handles
            params1 = filename_to_params_dict(file_path1, data_folder_name, skip_handles=skip_handles)
            params2 = filename_to_params_dict(file_path2, data_folder_name, skip_handles=skip_handles)
            
            exclude_keys = set([exclude_key] + skip_handles)
            return {k: v for k, v in params1.items() if k not in exclude_keys} == \
                   {k: v for k, v in params2.items() if k not in exclude_keys}
                   
        def get_file_with_max_tmax(files):
            # Pick the file whose name encodes the largest integer tmax ('tmax_<N>us')
            tmax_files = [(file, int(match.group(1))) for file in files if (match := re.search(r'tmax_(\d+)us', file))]
            return max(tmax_files, key=lambda x: x[1])[0] 
    
        # Saved data files ('.qu', the suffix qload reads) inside directories whose
        # names start with data_folder_name ('generated_data' by default)
        files = [os.path.relpath(os.path.join(dirpath, f))
                 for dirpath, _, filenames in os.walk('.')
                 if os.path.basename(dirpath).startswith(data_folder_name)
                 for f in filenames
                 if f.endswith('.qu')]
        
        # Files that differ from the requested one only in tmax
        list_matching_files = [file for file in files
                               if check_if_file_differs_by_one_param(file_path, file, data_folder_name, 'tmax', skip_handles)]
    
        response = None  # stays None when there is no similar file to reuse
        if list_matching_files:
            # A similar file exists - should it be used to reduce calculation time?
            if enable_user_responses:
                response = input('Found similar files. Use the existing data? (yes/no)?')
            else:
                response = default_response
            
            if response == 'yes':
                filepath_maxtmax = get_file_with_max_tmax(list_matching_files) # by default take the longest tmax file
                
                f1 = os.path.splitext(filepath_maxtmax)[0]            
                
                data_p1 = qload(f1) # get data from the existing file
                tlist_p1 = np.array(data_p1.times)
                tlist_p2 = tlist[tlist > tlist_p1.max()]
                mot0_p2 = data_p1.states[-1] # final state of p1 is the starting state of p2
                
                # Solve for the remaining times
                if tlist_p2.size != 0:
                    # Restart at the last stored time of p1 so the state and the drive phase
                    # cos/sin(w t) stay consistent; the first returned point duplicates the
                    # final p1 state and is dropped when combining.
                    tlist_solve = np.concatenate(([tlist_p1[-1]], tlist_p2))
                    data_p2 = run_mesolve(H0, H1_const, H1_cos, H1_sin, mot0_p2, tlist_solve, v_lattice, dm0, aIP_op, aOP_op, check_norm_flag=True)
        
                    # Construct a new result object with the combined data and save it
                    data = Result(e_ops=data_p2.e_ops, options=data_p2.options) 
                    data.times = np.concatenate((data_p1.times, data_p2.times[1:]))
                    # Keep states as a list of Qobj (np.concatenate would convert them to a numeric array)
                    data.states = list(data_p1.states) + list(data_p2.states[1:])
                    # Combine the expectation values as well (QuTiP 5 Result.e_data dict)
                    data.e_data = {key: list(data_p1.e_data[key]) + list(data_p2.e_data[key][1:])
                                   for key in data_p2.e_data}
                    save_qutip_data(data, 'q2', exp_params_dict, data_folder_name, skip_handles=skip_handles)
                else:
                    print('Exact file already exists!')
                    data = data_p1
        
        if response != 'yes':     
            # Solve from scratch if there is no similar file or if everything should be recomputed
            data = run_mesolve(H0, H1_const, H1_cos, H1_sin, mot0, tlist, v_lattice, dm0, aIP_op, aOP_op, check_norm_flag=True)    
            save_qutip_data(data, 'q2', exp_params_dict, data_folder_name, skip_handles=skip_handles)            
    
    else:
        print("(File:) Skipped repeated calculations.")
        
        # Load the previously saved data
        data = qload(os.path.splitext(file_path)[0])
    
    
    
    ### ---- DATA PLOTTING --- ####
    
    # Folder to save figures
    figure_save_folder_path = data_folder_name
    
    # Parameters used in legends
    params_dict = filename_to_params_dict(file_path, figure_save_folder_path, skip_handles=skip_handles)
    
    # Time-point indices at which the phonon distribution is inspected
    timepoint_list_points=[1000,5000,9000]
    
    if True:
        plot_expectation_vs_time(data, params_dict, figure_save_folder_path, 
                                      savePlotFlag=True, newFigureFlag=False, skip_handles=skip_handles,
                                      legend_params =['nOPinit', ''], timepoint_list=timepoint_list_points, mode_plt=0)
        
    if True:
        plot_phonon_distribution(data, params_dict, figure_save_folder_path,
                                     timepoint_list=timepoint_list_points, savePlotFlag=True, 
                                     newFigureFlag=False, skip_handles=skip_handles, title_params=['Eac0'], mode_plt=0)